Adds PropagationNet GNN layer and makes it optionally usable in existing deterministic models #507
Conversation
- Update test_datasets.py to use ForecasterModule instead of GraphLAM
- Update test_plotting.py to use ForecasterModule instead of GraphLAM
- Fix interior_mask_bool property shape (1,) -> (N,) for correct loss masking
- Fix all_gather_cat to handle single-device runs without incorrect dim collapse
…r hierarchy
- Replace opaque argparse.Namespace with explicit keyword arguments in StepPredictor, BaseGraphModel, BaseHiGraphModel, GraphLAM, HiLAM, and HiLAMParallel __init__ methods
- Reorder methods in step_predictor.py: forward/expand_to_batch now appear before clamping methods
- Update all instantiation sites (train_model.py, test_training.py, test_prediction_model_classes.py) to pass explicit kwargs
- HiLAM helper methods (make_same/up/down_gnns) now use self.hidden_dim and self.hidden_layers instead of args parameter

Addresses review comments on PR mllam#208.
- Rename border to boundary in Forecaster
- Pass Forecaster object to ForecasterModule init instead of Predictor
- Remove inline imports in ForecasterModule
- Move loss-related pred_std logic fully into ForecasterModule
- Delete obsolete test_refactored_hierarchy.py
Co-authored-by: Joel Oskarsson <joel.oskarsson@outlook.com>
- Add predicts_std property to StepPredictor, Forecaster and ARForecaster so ForecasterModule can query the forecaster instead of taking output_std as a separate constructor argument
- Remove output_std parameter from ForecasterModule; use self._forecaster.predicts_std throughout
- Move fallback per_var_std logic out of forecast_for_batch into each step method so pred_std is None before fallback, enabling direct None checks instead of hparam checks
- Replace len(datastore.boundary_mask) with datastore.num_grid_points in StepPredictor to avoid relying on boundary_mask
- Move get_state_feature_weighting and ARForecaster inline imports to module-level imports in forecaster_module.py and train_model.py
- Fix statement ordering in StepPredictor.__init__ so register_buffer for grid_static_features appears directly after building the tensor
- Replace dict+loop pattern for registering state_mean/state_std buffers with two direct register_buffer calls
- Remove all internal Item N checklist references from comments
- Remove TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD env var hack; pass weights_only=False explicitly to load_from_checkpoint calls and weights_only=True to torch.load in test_graph_creation.py
- Add test_step_predictor_no_static_features to verify models initialise and run correctly when the datastore returns None for static features
- Fix graph= -> graph_name= and model.forecaster -> model._forecaster in tests to match current API
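The predicts_std delegation described in the first two bullets can be sketched as below. This is an illustrative, stripped-down version with hypothetical class bodies, not the actual neural-lam implementation; only the names StepPredictor, ARForecaster, ForecasterModule, and predicts_std come from the PR.

```python
# Sketch (hypothetical bodies) of the predicts_std property pattern:
# ForecasterModule queries the forecaster rather than taking an
# output_std constructor argument of its own.
class StepPredictor:
    def __init__(self, output_std=False):
        self._output_std = output_std

    @property
    def predicts_std(self):
        return self._output_std


class ARForecaster:
    def __init__(self, step_predictor):
        self._step_predictor = step_predictor

    @property
    def predicts_std(self):
        # Delegate to the underlying single-step predictor
        return self._step_predictor.predicts_std


class ForecasterModule:
    def __init__(self, forecaster):
        self._forecaster = forecaster

    def loss_inputs(self, pred_mean, pred_std, per_var_std):
        # pred_std is None before the fallback, so a direct None
        # check suffices instead of inspecting hparams
        if pred_std is None:
            pred_std = per_var_std
        return pred_mean, pred_std
```

With this arrangement a single `self._forecaster.predicts_std` check replaces the former constructor flag throughout the module.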
…r_batch

Makes the forecasting path tolerant to batch-folded execution, so that future ensemble generation can fold (S, B) into (S*B) before calling ARForecaster, without any changes to ARForecaster or StepPredictor. The prediction is kept folded through the existing deterministic logging and aggregation paths, so all dim assumptions in training_step, validation_step, and test_step remain correct. Unfolding to (*leading, T, N, F) is deferred to ensemble-specific subclasses (e.g. EnsForecasterModule). Adds test_fold_unfold_equivalence to confirm ARForecaster's rollout is rank-transparent under a pre-entry fold.
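The fold/unfold convention above amounts to a pair of reshapes around the rollout. A minimal sketch, shown with NumPy for brevity (`torch.Tensor.reshape` behaves identically); the shape names and the comment about ARForecaster are taken from the commit message, everything else is illustrative:

```python
import numpy as np

# S ensemble samples, B batch members, T time steps, N grid nodes, F features
S, B, T, N, F = 2, 3, 4, 5, 6
forecast_in = np.arange(S * B * T * N * F, dtype=np.float64).reshape(S, B, T, N, F)

# Pre-entry fold: (S, B, ...) -> (S*B, ...); the rollout then treats
# S*B as a plain batch dimension (this is what "rank-transparent" means)
folded = forecast_in.reshape(S * B, T, N, F)

# ... ARForecaster rollout would act on `folded` here ...

# Ensemble-side unfold, deferred to subclasses in the PR
unfolded = folded.reshape(S, B, T, N, F)
```

Because reshape is a pure view of the same memory layout, folding and unfolding round-trips exactly, which is what test_fold_unfold_equivalence checks for the real rollout.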
…stic models
- Port PropagationNet as InteractionNet subclass (mean aggr, sender residual in messages, aggregation residual in forward)
- Add --vertical_propnets CLI flag to select PropagationNet for grid-mesh and vertical message passing edges
- Wire flag through model hierarchy: BaseGraphModel (g2m/m2g), BaseHiGraphModel (mesh init), HiLAM (up GNNs)
- Add 13 tests covering unit behavior and backward compatibility
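The three differences listed in the first bullet (mean aggregation, sender residual in messages, aggregation residual in the node update) can be illustrated with a toy NumPy step. This is a sketch of the described update rule, not the PyG-based implementation; all function names here are hypothetical, and the MLPs are passed in as plain callables:

```python
import numpy as np

def propnet_step(h_rec, h_send, edge_index, edge_mlp, aggr_mlp):
    """Toy PropagationNet-style update on (num_nodes, dim) arrays."""
    senders, receivers = edge_index  # each of shape (E,)
    # Edge update from concatenated (receiver, sender) states
    edge_in = np.concatenate([h_rec[receivers], h_send[senders]], axis=-1)
    # Sender residual in messages: m_ij = edge_mlp(...) + h_send
    messages = edge_mlp(edge_in) + h_send[senders]
    # MEAN aggregation per receiver (InteractionNet sums instead)
    aggr = np.zeros_like(h_rec)
    counts = np.zeros(len(h_rec))
    np.add.at(aggr, receivers, messages)
    np.add.at(counts, receivers, 1.0)
    aggr = aggr / np.maximum(counts, 1.0)[:, None]
    # Residual added to the AGGREGATION, not the old receiver state
    return aggr + aggr_mlp(np.concatenate([h_rec, aggr], axis=-1))
```

Swapping the last line for `h_rec + update` and the mean for a sum would recover the InteractionNet-style update, which is why the PR can express PropagationNet as a subclass.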
@joeloskarsson @sadamov @observingClouds please have a look and see if this qualifies as the next step in ensemble prep 😄. Would be grateful for your feedback!
@joeloskarsson @sadamov I would appreciate your further review on this PR!
joeloskarsson
left a comment
Thanks for looking at this, and sorry for being slow with reviews 😅 Shared some first thoughts here on how I think we can best integrate this. Happy to hear input also from others, as there are some non-trivial design choices around this (e.g. how to choose the GNN type for each sub-graph).
```python
num_past_forcing_steps: int = 1,
num_future_forcing_steps: int = 1,
output_std: bool = False,
vertical_propnets: bool = False,
```
I think we need more fine-grained options for where to use propnets here (g2m, m2g, both). We should also make this more future-proof by not making it a boolean inet/propnet, but rather having the argument be the GNN type to use (as a string? enum? not sure about best design).
This goes also for other model classes with this argument.
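One possible shape for this suggestion, sketched under the assumption that the argument becomes a per-sub-graph string parsed into an enum (all names here are hypothetical, since the reviewer explicitly leaves the design open):

```python
from enum import Enum

class GNNType(str, Enum):
    """GNN variant selectable per sub-graph instead of a single boolean."""
    INTERACTION = "interaction"
    PROPAGATION = "propagation"


def select_gnn_types(g2m="interaction", m2g="interaction"):
    """Parse per-edge-set GNN choices, e.g. from argparse string values."""
    return {"g2m": GNNType(g2m), "m2g": GNNType(m2g)}
```

A string-backed enum keeps the CLI surface human-readable (`--g2m_gnn propagation`) while leaving room for future GNN variants beyond the inet/propnet pair.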
```python
# Always concatenate to [rec_nodes, send_nodes] for propagation,
# but only aggregate to rec_nodes
node_reps = torch.cat((rec_rep, send_rep), dim=-2)
edge_rep_aggr, edge_diff = self.propagate(
    self.edge_index, x=node_reps, edge_attr=edge_rep
)
rec_diff = self.aggr_mlp(
    torch.cat((rec_rep, edge_rep_aggr), dim=-1)
)

# Residual connections
rec_rep = edge_rep_aggr + rec_diff  # residual is to aggregation

if self.update_edges:
    edge_rep = edge_rep + edge_diff
    return rec_rep, edge_rep

return rec_rep
```
There is quite a lot of code repeated from InteractionNet's forward here. Could this be refactored to avoid the repetition, while keeping the differences between the two classes clear?
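One way to act on this comment is a template-method split: keep the shared forward in the base class and override only the residual hook where the two layers differ. A minimal sketch with hypothetical names (the real classes would keep their MessagePassing machinery; `propagate_fn` here is a stand-in callable, and the PropagationNet update MLP is elided):

```python
class BaseGNNSketch:
    """Shared forward; subclasses differ only in the residual hook."""
    update_edges = False

    def __init__(self, propagate_fn):
        # propagate_fn returns (aggregated messages, edge update)
        self.propagate_fn = propagate_fn

    def forward(self, rec_rep, edge_rep):
        edge_rep_aggr, edge_diff = self.propagate_fn(rec_rep, edge_rep)
        rec_rep = self.node_residual(rec_rep, edge_rep_aggr)
        if self.update_edges:
            return rec_rep, edge_rep + edge_diff
        return rec_rep

    def node_residual(self, rec_rep, edge_rep_aggr):
        raise NotImplementedError


class InteractionSketch(BaseGNNSketch):
    def node_residual(self, rec_rep, edge_rep_aggr):
        # Residual to the previous receiver state
        return rec_rep + edge_rep_aggr


class PropagationSketch(BaseGNNSketch):
    def node_residual(self, rec_rep, edge_rep_aggr):
        # Residual to the aggregation itself
        rec_diff = 0.0  # stands in for aggr_mlp([rec_rep, edge_rep_aggr])
        return edge_rep_aggr + rec_diff
```

This keeps the edge-update and return paths in one place while making the single behavioural difference explicit in each subclass.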
Describe your changes
Adds a `PropagationNet` GNN layer and makes it optionally usable in existing deterministic models, as outlined in #62. It is integrated into the existing model hierarchy from #208 and can be enabled via the `vertical_propnets` flag. Depends on #208.
For changes on top of #208 only, see:
Sir-Sloth-The-Lazy/neural-lam@refactor/model-class-hierarchy-issue-49...refactor/batch-fold-ensemble-prep
Issue Link
Contributes to #62
Type of change
Checklist before requesting a review
`pull` with `--rebase` option if possible).
Checklist for reviewers
Each PR comes with its own improvements and flaws. The reviewer should check the following:
Author checklist after completed review
reflecting type of change (add section where missing):
Checklist for assignee